import torch
import numpy as np

from torch.utils.data import Dataset
from .cache import getCt, getCtSampleSize
from utils.candidate import getCandidateInfoDict, getCandidateInfoList
from utils.logconf import logging

# Module-level logger; set to DEBUG so dataset construction details are
# visible during training and validation runs.
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)

# Updating the dataset for segmentation
#
# Our model expects input and will produce output of a different form than we
# had previously. Our previous dataset produced 3D data, but we need to produce
# 2D data now.
#
# The original U-Net implementation did not use padded convolutions, which
# means while the output segmentation map was smaller than the input, every
# pixel of that output had a fully populated receptive field. None of the input
# pixels that fed into the determination of that output pixel were padded,
# fabricated, or otherwise incomplete. Thus the output of the original U-Net
# will tile perfectly, so it can be used with images of any size (except at
# the edges of the input image, where some context will be missing by definition).
#
# There are two problems with us taking the same pixel-perfect approach for our
# problem. The first is related to the interaction between convolution and
# downsampling, and the second is related to the nature of our data being
# three-dimensional.
#
# U-Net has very specific input size requirements
#
# The first issue is that the sizes of the input and output patches for U-Net
# are very specific. We will address this issue by setting the padding flag of
# the U-Net constructor to True. This will mean we can use input images of any
# size, and we will get output of the same size. We may lose some fidelity near
# the edges of the image, since the receptive field of pixels located there
# will include regions that have been artificially padded, but that's a
# compromise we decide to live with.
#
# U-Net trade-offs for 3D vs. 2D data
#
# The second issue is that our 3D data doesn't line up exactly with U-Net's 2D
# expected input. Simply taking our 512x512x128 image and feeding it into a
# converted-to-3D U-Net class won't work, because we'll exhaust our GPU memory.
#
# As anticipated, instead of trying to do things in 3D, we're going to treat
# each slice as a 2D segmentation problem and cheat our way around the issue of
# context in the third dimension by providing neighboring slices as separate
# channels.
#
# We lose the direct spatial relationship between slices when represented as
# channels, as all channels will be linearly combined by the convolution kernels
# with no notion of them being one or two slices away, above or below. We also
# lose the wider receptive field in the depth dimension that would come from a
# true 3D segmentation with downsampling. Since CT slices are often thicker than
# the resolution in rows and columns, we do get a somewhat wider view than it
# seems at first, and this should be enough, considering that nodules typically
# span a limited number of slices.

# Implementing Luna2dSegmentationDataset
#
# We will have two classes: one acting as a general base class suitable for
# validation data, and one subclassing the base for the training set, with
# randomization and a cropped sample. It actually simplifies the logic of
# selecting randomized training samples and the like, because our training data
# will look significantly different from our validation data!
#
# The data that we produce will be two-dimensional CT slices with multiple
# channels. The extra channels will hold adjacent slices of CT. Each slice of
# CT scan can be thought of as a 2D grayscale image.
#
# For the input to our classification model, we treated those slices as a 3D
# array of data and used 3D convolutions to process each sample. For our
# segmentation model, we are going to instead treat each slice as a single
# channel, and produce a multichannel 2D image. Doing so will mean that we are
# treating each slice of CT scan as if it was a color channel of an RGB image.
# Each input slice of the CT will get stacked together and consumed just like
# any other 2D image. The channels of our stacked CT image won't correspond to
# colors, but nothing about 2D convolutions requires the input channels to be
# colors, so it works out fine.
#
# For validation, we'll need to produce one sample per slice of CT that has an
# entry in the positive mask, for each validation CT we have. Since different
# CT scans can have different slice counts, we're going to introduce a new
# function that caches the size of each CT scan and its positive mask to disk.
# We need this to be able to quickly construct the full size of a validation
# set without having to load each CT at Dataset initialization.


class Luna2dSegmentationDataset(Dataset):
    """2D segmentation dataset over full CT slices.

    Each sample is a multichannel 2D image: the target slice plus
    ``contextSlices_count`` neighboring slices on each side, stacked as
    channels, together with the matching single-channel positive (nodule)
    mask for the center slice. Serves as the base class for the training
    dataset; on its own it is suitable for validation.
    """

    def __init__(
        self,
        val_stride=0,
        isValSet_bool=None,
        series_uid=None,
        contextSlices_count=3,
        fullCt_bool=False,
    ):
        """
        Args:
            val_stride: Every ``val_stride``-th series goes to the
                validation split (and is removed from the training split).
            isValSet_bool: True -> validation split, False -> training
                split, None -> all series ("general" mode).
            series_uid: If given, restrict the dataset to that one series.
            contextSlices_count: Number of neighboring slices on EACH side
                of the target slice to stack as extra channels.
            fullCt_bool: True -> one sample per slice of every CT
                (end-to-end evaluation, no prior information assumed);
                False -> only slices that have a positive mask entry
                (validation during training).
        """
        self.contextSlices_count = contextSlices_count
        self.fullCt_bool = fullCt_bool

        if series_uid:
            self.series_list = [series_uid]
        else:
            self.series_list = sorted(getCandidateInfoDict().keys())

        if isValSet_bool:
            assert val_stride > 0, val_stride
            self.series_list = self.series_list[::val_stride]
            assert self.series_list
        elif val_stride > 0:
            del self.series_list[::val_stride]
            assert self.series_list

        # Build (series_uid, slice_ndx) pairs without loading any CT voxel
        # data: getCtSampleSize reads cached per-series sizes from disk, so
        # Dataset construction stays fast.
        self.sample_list = []
        for series_uid in self.series_list:
            index_count, positive_indexes = getCtSampleSize(series_uid)

            # Two validation modes. When fullCt_bool is True, we use every
            # slice in the CT — useful for end-to-end evaluation, where we
            # pretend to start with no prior information about the CT.
            # Otherwise we limit ourselves to only the CT slices that have
            # a positive mask present (validation during training).
            if self.fullCt_bool:
                self.sample_list += [
                    (series_uid, slice_ndx) for slice_ndx in range(index_count)
                ]
            else:
                self.sample_list += [
                    (series_uid, slice_ndx) for slice_ndx in positive_indexes
                ]

        self.candidateInfo_list = getCandidateInfoList()

        # Keep only candidates belonging to the series in this split.
        series_set = set(self.series_list)
        self.candidateInfo_list = [
            cit for cit in self.candidateInfo_list if cit.series_uid in series_set
        ]

        # For the data balancing yet to come, we want a list of actual nodules.
        self.pos_list = [nt for nt in self.candidateInfo_list if nt.isNodule_bool]

        log.info(
            "{!r}: {} {} series, {} slices, {} nodules".format(
                self,
                len(self.series_list),
                {None: "general", True: "validation", False: "training"}[isValSet_bool],
                len(self.sample_list),
                len(self.pos_list),
            )
        )

    def __len__(self):
        return len(self.sample_list)

    def __getitem__(self, ndx):
        # Modulo lets callers iterate past len() without an IndexError.
        series_uid, slice_ndx = self.sample_list[ndx % len(self.sample_list)]
        return self.getitem_fullSlice(series_uid, slice_ndx)

    def getitem_fullSlice(self, series_uid, slice_ndx):
        """Return ``(ct_t, pos_t, series_uid, slice_ndx)`` for one slice.

        ``ct_t`` has shape ``(2 * contextSlices_count + 1, rows, cols)``;
        ``pos_t`` is the ``(1, rows, cols)`` positive mask for the center
        slice.
        """
        ct = getCt(series_uid)
        # Derive the in-plane size from the CT itself instead of assuming
        # 512x512, so CTs of any row/column size work.
        ct_t = torch.zeros(
            (self.contextSlices_count * 2 + 1,) + tuple(ct.hu_a.shape[1:])
        )

        start_ndx = slice_ndx - self.contextSlices_count
        end_ndx = slice_ndx + self.contextSlices_count + 1
        last_slice_ndx = ct.hu_a.shape[0] - 1  # loop-invariant upper bound
        for i, context_ndx in enumerate(range(start_ndx, end_ndx)):
            # When we reach beyond the bounds, we duplicate the first or
            # last slice.
            context_ndx = min(max(context_ndx, 0), last_slice_ndx)
            ct_t[i] = torch.from_numpy(ct.hu_a[context_ndx].astype(np.float32))

        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
        # HU are scaled oddly, with air (~0 g/cc) being -1000 and water (~1 g/cc)
        # being 0. The lower bound gets rid of negative-density stuff used to
        # indicate out-of-FOV; the upper bound nukes any weird hotspots and
        # clamps bone down.
        ct_t.clamp_(-1000, 1000)

        pos_t = torch.from_numpy(ct.positive_mask[slice_ndx]).unsqueeze(0)

        return ct_t, pos_t, ct.series_uid, slice_ndx
